# Copyright (c) 2024-present AI-Labs

from fastapi import APIRouter
from fastapi.responses import FileResponse

from diffusers import StableDiffusionXLPipeline, DPMSolverSinglestepScheduler, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline
from diffusers.utils import load_image
from .stable_diffusion import *

from datetime import datetime

import torch
import os
import numpy as np
import random
import uuid

from configs import config

"""
路由信息设置
"""
router = APIRouter(
    prefix='/image/stable_diffusion_xl',
    tags = ['图片生成']
)

device = "cuda"
dtype = torch.float16

# GiteeAI 平台部署加速
# text2imgPipe = StableDiffusionXLPipeline.from_pretrained("hf-models/sdxl-flash", torch_dtype=dtype).to(device)

# 下载到本地
# text2imgPipe = StableDiffusionXLPipeline.from_pretrained("models/sd-community/sdxl-flash", torch_dtype=dtype).to(device)

# 使用配置文件
# text2imgPipe = StableDiffusionXLPipeline.from_pretrained(config.service.stable_diffusion_xl.model_path, torch_dtype=dtype).to(device)
# text2imgPipe.scheduler = DPMSolverSinglestepScheduler.from_config(text2imgPipe.scheduler.config, timestep_spacing="trailing")

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


"""
生成响应信息，根据用户请求生成满足格式的响应数据，主要是将图片以不同的方式返回给用户
"""
def response_output(images, body):
    """
    Build the HTTP response for a list of generated PIL images.

    Supported ``body.output_format`` values:
      * ``"url"``    — save images under the statics directory and return their public URLs.
      * ``"base64"`` — return inline ``data:`` URI base64 payloads.
      * ``"file"``   — save images and stream only the FIRST one back as a file.
    Any other value returns the raw image list unchanged.
    """
    if body.output_format == "url":
        localfiles = _save_images(images)
        urls = [{"url": f"{config.setting.statics.urls}/{localfile}"} for localfile in localfiles]
        return ImageResponse(status="success", format=body.output_format, data=urls, meta={})
    elif body.output_format == "base64":
        base64s = [{"b64_json": f"data:image;base64,{image_to_base64(image)}"} for image in images]
        return ImageResponse(status="success", format=body.output_format, data=base64s, meta={})
    elif body.output_format == "file":
        localfiles = _save_images(images)
        # Images are saved as PNG, so advertise the matching media type
        # (was incorrectly "image/jpeg"). Only the first image can be streamed back.
        return FileResponse(f"{config.setting.statics.path}/{localfiles[0]}", media_type="image/png")

    return images


def _save_images(images):
    """Save *images* as PNG files in a per-day statics folder; return their relative paths."""
    localdir = f"image/{datetime.now().strftime('%Y-%m-%d')}"
    os.makedirs(f"{config.setting.statics.path}/{localdir}", exist_ok=True)

    localfiles = []
    for image in images:
        localfile = f"{localdir}/{uuid.uuid4()}.png"
        image.save(f"{config.setting.statics.path}/{localfile}")
        localfiles.append(localfile)
    return localfiles


"""
如果启用了文本生成图片，则启用文本生成图片相关的 Pipe、服务等
"""
if config.service.stable_diffusion_xl.enable_text2img:
    # Load the text-to-image pipeline from the configured model path and switch
    # to the single-step DPM solver with "trailing" timestep spacing.
    text2imgPipe = StableDiffusionXLPipeline.from_pretrained(config.service.stable_diffusion_xl.model_path, torch_dtype=dtype).to(device)
    text2imgPipe.scheduler = DPMSolverSinglestepScheduler.from_config(text2imgPipe.scheduler.config, timestep_spacing="trailing")

    def text2img_generate(body: Text2ImageBody):
        """
        Run the text-to-image pipeline and return the generated PIL images.

        When ``body.randomize_seed`` is set, a fresh seed is drawn and written
        back into ``body.seed`` so callers can see which seed was used.
        """
        real_prompt = generate_prompt(body.prompt)

        if body.randomize_seed:
            body.seed = random.randint(0, MAX_SEED)
        generator = torch.Generator().manual_seed(body.seed)

        images = text2imgPipe(
            prompt=real_prompt,
            negative_prompt=body.negative_prompt,
            guidance_scale=body.guidance_scale,
            num_inference_steps=body.num_inference_steps,
            width=body.width,
            height=body.height,  # BUG FIX: was body.width, which ignored the requested height
            num_images_per_prompt=body.samples,
            generator=generator,
            output_type="pil"
        ).images

        # Release cached GPU memory between requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        return images


    """
    Text-to-image endpoint.
    """
    @router.post("/text2img")
    def text2img(body: Text2ImageBody):
        images = text2img_generate(body)
        return response_output(images, body)


"""
如果启用了图生图，则启用图生图相关的 Pipe、服务等
"""
if config.service.stable_diffusion_xl.enable_img2img:
    # Load the image-to-image pipeline from the same configured model path.
    img2imgPipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(config.service.stable_diffusion_xl.model_path, torch_dtype=dtype).to(device)

    def img2img_generate(body: Img2ImgBody):
        """
        Run the image-to-image pipeline and return the generated PIL images.

        ``body.init_image`` may be either a ``data:`` base64 URI or a URL/path
        loadable by ``diffusers.utils.load_image``. When ``body.randomize_seed``
        is set, a fresh seed is drawn and written back into ``body.seed``.
        """
        real_prompt = generate_prompt(body.prompt)

        if body.randomize_seed:
            body.seed = random.randint(0, MAX_SEED)
        generator = torch.Generator().manual_seed(body.seed)

        images = img2imgPipe(
            prompt=real_prompt,
            image=(base64_to_image(body.init_image.split(',')[1], "init_image") if body.init_image.startswith("data:") else load_image(body.init_image)).convert("RGB"),
            negative_prompt=body.negative_prompt,
            guidance_scale=body.guidance_scale,
            num_inference_steps=body.num_inference_steps,
            width=body.width,
            height=body.height,  # BUG FIX: was body.width, which ignored the requested height
            num_images_per_prompt=body.samples,
            generator=generator,
            output_type="pil"
        ).images

        # Release cached GPU memory between requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        return images


    """
    Image-to-image endpoint.
    """
    @router.post("/img2img")
    def img2img(body: Img2ImgBody):
        images = img2img_generate(body)
        return response_output(images, body)


"""
如果启用了局部重绘，则启动局部重绘相关的 Pipe、服务
"""
if config.service.stable_diffusion_xl.enable_inpainting:
    # Load the inpainting pipeline from its own configured model path (fp16 variant).
    inpaintingPipe = StableDiffusionXLInpaintPipeline.from_pretrained(config.service.stable_diffusion_xl.inpainting_model_path, torch_dtype=dtype, variant="fp16").to(device)

    def inpainting_generate(body: InpaintingBody):
        """
        Run the inpainting pipeline and return the generated PIL images.

        ``body.init_image`` and ``body.mask_image`` may each be a ``data:``
        base64 URI or a URL/path loadable by ``diffusers.utils.load_image``.
        When ``body.randomize_seed`` is set, a fresh seed is drawn and written
        back into ``body.seed``.
        """
        real_prompt = generate_prompt(body.prompt)

        if body.randomize_seed:
            body.seed = random.randint(0, MAX_SEED)
        generator = torch.Generator().manual_seed(body.seed)

        images = inpaintingPipe(
            prompt=real_prompt,
            image=(base64_to_image(body.init_image.split(',')[1], "init_image") if body.init_image.startswith("data:") else load_image(body.init_image)).convert("RGB"),
            mask_image=(base64_to_image(body.mask_image.split(',')[1], "mask_image") if body.mask_image.startswith("data:") else load_image(body.mask_image)).convert("RGB"),
            negative_prompt=body.negative_prompt,
            guidance_scale=body.guidance_scale,
            num_inference_steps=body.num_inference_steps,
            width=body.width,
            height=body.height,  # BUG FIX: was body.width, which ignored the requested height
            num_images_per_prompt=body.samples,
            generator=generator,
            output_type="pil"
        ).images

        # Release cached GPU memory between requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        return images


    """
    Inpainting endpoint.
    """
    @router.post("/inpainting")
    def inpainting(body: InpaintingBody):
        images = inpainting_generate(body)
        return response_output(images, body)
